Wrap sync + a2a in a custom op #1597
Conversation
out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
T_out = int(sum(out_splits))
y = x.new_empty((T_out,) + tuple(x.shape[1:]))
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
Turning on async_op=True gives NaN values in the loss both before and after this PR, with or without AC. Is this expected?
cc @xmfan
Where's the wait? Dynamo should graph break on async_op=True.
"Where's the wait?"

Ah, this might be the issue. I guess we'd have to either wrap the outputs in an async tensor or manually call wait() at the usage site.

"Dynamo should graph break on async_op=True"

Dynamo doesn't see this since it's in a custom op.
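(For reference, a minimal sketch of what the missing wait would look like, assuming the same x/y/splits/group variables as in the snippet above; with async_op=True the collective returns a Work handle and y must not be read before wait() returns:)

import torch.distributed as dist

def _a2a_async(y, x, out_splits, in_splits, group):
    # async_op=True returns a Work handle instead of completing before returning.
    work = dist.all_to_all_single(
        y, x.contiguous(), out_splits, in_splits, group=group, async_op=True
    )
    work.wait()  # without this, downstream ops may read y before it is filled
    return y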
    out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
else:
    out_splits_cpu, in_splits_cpu = out_splits.to(device="cpu", non_blocking=True), in_splits.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()
There seems to be a difference between a non-blocking .to followed by an explicit sync (what is done here) and just calling .to with non_blocking=False, which is supposed to perform a CUDA stream sync itself. Only the former works here; I'm not sure why yet.
cc @ngimel
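(For illustration, a sketch of the two variants being compared; variable names follow the snippet above:)

import torch

def _splits_to_cpu(out_splits, in_splits):
    # Variant A (what the PR does): non-blocking D2H copies followed by an explicit
    # stream sync, after which .tolist() is safe to call.
    out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
    in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()
    return out_splits_cpu.tolist(), in_splits_cpu.tolist()

# Variant B (reported not to work here): blocking copies, which are supposed to
# synchronize on their own before returning.
# out_splits_cpu = out_splits.to(device="cpu", non_blocking=False)
# in_splits_cpu = in_splits.to(device="cpu", non_blocking=False)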
What is the error you are getting? Wrong results?
Originally, "RuntimeError: Split sizes doesn't match total dim 0 size", but now I'm no longer able to reproduce it...
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)

# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
# if the a2a is saved.
@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
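(For context, a rough sketch of what the wrapped op could look like; the argument names, the use of a group name to resolve the process group, and the exact body are assumptions pieced together from the snippets quoted in this thread, not the PR's actual code:)

import torch
import torch.distributed as dist

@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
def _sync_and_all_to_all(
    x: torch.Tensor,
    out_splits: torch.Tensor,
    in_splits: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Resolve the process group from its name (custom ops can't take ProcessGroup args).
    group = dist.distributed_c10d._resolve_process_group(group_name)
    # D2H copy of the split sizes, then a stream sync so .tolist() is safe.
    out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
    in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()
    out_list, in_list = out_splits_cpu.tolist(), in_splits_cpu.tolist()
    # Allocate the output from the now host-side split sizes and run the a2a.
    T_out = int(sum(out_list))
    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
    dist.all_to_all_single(y, x.contiguous(), out_list, in_list, group=group, async_op=False)
    return y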
I don't think we should do this. Hiding these comms will prevent inductor passes from reordering around them... We won't be able to overlap shared experts with either token dispatch or combine via the compiler.
I have two comments.
- Here it's grouping (1) the get-splits-info a2a, (2) the d2h sync, and (3) the token (un)permutation a2a into a single op. Since the shared-expert overlapping targets (3), I wonder if we can just group (1) and (2) in a custom op and separately AC this custom op and (3)? (A rough sketch of that split follows this list.)
- Fwiw, DeepSeek V3 shared experts are small and the a2a's are big (topk=8), so I've heard shared-expert overlapping itself has limited value, and we'd probably need to rely on DualPipe-style overlapping. But I'm not sure whether implementing DualPipe has any requirements on the custom a2a ops. cc @H-Huang
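(A very rough sketch of that grouping, purely illustrative; the op name, the signature, and returning the synced splits as a CPU tensor are all assumptions:)

import torch
import torch.distributed as dist

@torch.library.custom_op("titan::_splits_a2a_and_sync", mutates_args=())
def _splits_a2a_and_sync(in_splits: torch.Tensor, group_name: str) -> torch.Tensor:
    # (1) exchange per-rank split sizes, then (2) D2H copy + sync; the caller keeps the
    # token (un)permutation a2a (3) as a separate op so SAC can treat them independently.
    group = dist.distributed_c10d._resolve_process_group(group_name)
    out_splits = torch.empty_like(in_splits)
    dist.all_to_all_single(out_splits, in_splits, group=group)
    out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()
    return out_splits_cpu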
A dumb way of doing this is to just have two paths, since this is an eager SAC optimization: eager will use the custom op, and compile will not.
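(A sketch of that kind of dispatch; both helper names are hypothetical:)

import torch

def _token_dispatch_a2a(x, out_splits, in_splits, group):
    if torch.compiler.is_compiling():
        # Under compile, keep the plain collective visible so inductor can reorder it.
        return _all_to_all_plain(x, out_splits, in_splits, group)  # hypothetical plain path
    # In eager, use the fused custom op so SAC can save the sync + a2a together.
    return torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, group.group_name)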
I think the two problems are unrelated, although it looks like they both have to do with the a2a code in titan (my PR is related to a correctness issue that Ruisi ran into).
I'm reading through this PR, but can someone describe the problem in a bit more detail? The two things that I got so far from reading are:
(1) there is an all2all in the moe code that it sounds like we don't want to recompute (but for some reason we are when SAC is on?)
(2) there is a torch.cuda.current_stream().synchronize() call in the token dispatch code, which compile is probably not handling very well today. And it looks like the current PR tries to throw it in a custom op as a workaround? (at the cost of making the custom op a "custom collective" that inductor won't be able to optimize, as @xmfan mentioned)
Yeah, for more context: today, in order to run the a2a, the input/output splits must be provided on the host, so we do a D2H sync before the a2a.
The issue is that if eager SAC saves the a2a, AC will still recompute the D2H sync to move the input/output splits to the host, even though it is not needed.
This PR tries to work around this by wrapping the D2H sync together with the a2a into a single custom op, so that saving this combined op in SAC prevents both from being recomputed.
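(For concreteness, a sketch of the eager SAC policy this enables, using torch.utils.checkpoint's selective-checkpoint API; the op name matches the decorator quoted above, everything else is illustrative:)

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def _policy(ctx, op, *args, **kwargs):
    # Saving the fused op keeps both the D2H sync and the a2a out of recompute.
    if op is torch.ops.titan._sync_and_all_to_all.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = partial(create_selective_checkpoint_contexts, _policy)
# out = checkpoint(moe_layer, x, use_reentrant=False, context_fn=context_fn)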
Update:
@bdhirsh pointed out to me that another alternative is to save the d2h op instead.
I was not able to try this originally due to #1597 (comment), but I'm no longer able to repro that!
So from a discussion with @tianyu-l offline, the current plan is to do this instead of using a custom op.
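(A sketch of that alternative: instead of fusing, mark the D2H copy itself as must-save in the SAC policy. This assumes the .to(device="cpu") dispatches to aten._to_copy; it is illustrative, not the final code:)

import torch
from torch.utils.checkpoint import CheckpointPolicy

def _policy(ctx, op, *args, **kwargs):
    # Save the D2H copy of the split-size tensors so recompute can reuse the host
    # values instead of re-doing the copy. In practice one may also want to check
    # that the target device is CPU and that this is the splits copy specifically.
    if op is torch.ops.aten._to_copy.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE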
I tried locally and it works for me!
Let's continue figuring out the strategy with compile.
Profiler traces:
Before PR (selective op AC, a2a not saved)
After PR (selective op AC, a2a is saved): only the cudaStreamSync in the forward